{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Implementation of Multilayer Perceptron from Scratch\n",
    "\n",
    ":label:`sec_mlp_scratch`\n",
    "\n",
    "\n",
    "Now that we have characterized \n",
    "multilayer perceptrons (MLPs) mathematically, \n",
    "let us try to implement one ourselves."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%load ../utils/djl-imports\n",
    "%load ../utils/plot-utils\n",
    "%load ../utils/DataPoints.java\n",
    "%load ../utils/Training.java\n",
    "%load ../utils/Accumulator.java"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import ai.djl.basicdataset.cv.classification.*;\n",
    "import org.apache.commons.lang3.ArrayUtils;"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To compare against our previous results\n",
    "achieved with (linear) softmax regression\n",
    "(:numref:`sec_softmax_scratch`),\n",
    "we will continue to work with \n",
    "the Fashion-MNIST image classification dataset \n",
    "(:numref:`sec_fashion_mnist`)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "attributes": {
     "classes": [],
     "id": "",
     "n": "2"
    }
   },
   "outputs": [],
   "source": [
    "int batchSize = 256;\n",
    "\n",
    "FashionMnist trainIter = FashionMnist.builder()\n",
    "        .optUsage(Dataset.Usage.TRAIN)\n",
    "        .setSampling(batchSize, true)\n",
    "        .optLimit(Long.getLong(\"DATASET_LIMIT\", Long.MAX_VALUE))\n",
    "        .build();\n",
    "\n",
    "\n",
    "FashionMnist testIter = FashionMnist.builder()\n",
    "        .optUsage(Dataset.Usage.TEST)\n",
    "        .setSampling(batchSize, true)\n",
    "        .optLimit(Long.getLong(\"DATASET_LIMIT\", Long.MAX_VALUE))\n",
    "        .build();\n",
    "                            \n",
    "trainIter.prepare();\n",
    "testIter.prepare();"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Initializing Model Parameters\n",
    "\n",
    "Recall that Fashion-MNIST contains $10$ classes,\n",
    "and that each image consists of a $28 \\times 28 = 784$\n",
    "grid of grayscale pixel values.\n",
    "Again, we will disregard the spatial structure\n",
    "among the pixels (for now),\n",
    "so we can think of this as simply a classification dataset\n",
    "with $784$ input features and $10$ classes.\n",
    "To begin, we will implement an MLP\n",
    "with one hidden layer and $256$ hidden units.\n",
    "Note that we can regard both of these quantities\n",
    "as *hyperparameters* and ought in general\n",
    "to set them based on performance on validation data.\n",
    "Typically, we choose layer widths in powers of $2$,\n",
    "which tend to be computationally efficient because\n",
    "of how memory is allotted and addressed in hardware.\n",
    "\n",
    "Again, we will represent our parameters with several `NDArray`s.\n",
    "Note that *for every layer*, we must keep track of\n",
    "one weight matrix and one bias vector.\n",
    "As always, we call `setRequiresGradient(true)` to allocate memory\n",
    "for the gradients (of the loss) with respect to these parameters."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "attributes": {
     "classes": [],
     "id": "",
     "n": "3"
    }
   },
   "outputs": [],
   "source": [
    "int numInputs = 784;\n",
    "int numOutputs = 10;\n",
    "int numHiddens = 256;\n",
    "\n",
    "NDManager manager = NDManager.newBaseManager();\n",
    "\n",
    "NDArray W1 = manager.randomNormal(\n",
    "                        0, 0.01f, new Shape(numInputs, numHiddens), DataType.FLOAT32);\n",
    "NDArray b1 = manager.zeros(new Shape(numHiddens));\n",
    "NDArray W2 = manager.randomNormal(\n",
    "                        0, 0.01f, new Shape(numHiddens, numOutputs), DataType.FLOAT32);\n",
    "NDArray b2 = manager.zeros(new Shape(numOutputs));\n",
    "\n",
    "NDList params = new NDList(W1, b1, W2, b2);\n",
    "\n",
    "for (NDArray param : params) {\n",
    "    param.setRequiresGradient(true);\n",
    "}"
   ]
  },
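  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a quick sanity check on the shapes above,\n",
    "we can count the learnable parameters of this network\n",
    "by adding up the entries of each weight matrix and bias vector\n",
    "(the symbols $\\mathbf{W}_1, \\mathbf{b}_1, \\mathbf{W}_2, \\mathbf{b}_2$\n",
    "stand for the arrays `W1`, `b1`, `W2`, and `b2`):\n",
    "\n",
    "$$\\underbrace{784 \\times 256}_{\\mathbf{W}_1} + \\underbrace{256}_{\\mathbf{b}_1} + \\underbrace{256 \\times 10}_{\\mathbf{W}_2} + \\underbrace{10}_{\\mathbf{b}_2} = 200704 + 256 + 2560 + 10 = 203530.$$\n",
    "\n",
    "Nearly all of these parameters sit in the first weight matrix,\n",
    "so changing `numHiddens` changes the model size almost proportionally."
   ]
  },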
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Activation Function\n",
    "\n",
    "To make sure we know how everything works,\n",
    "we will implement the ReLU activation ourselves\n",
    "using the `maximum` function rather than \n",
    "invoking `Activation.relu` directly."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "attributes": {
     "classes": [],
     "id": "",
     "n": "4"
    }
   },
   "outputs": [],
   "source": [
    "public NDArray relu(NDArray X){\n",
    "    return X.maximum(0f);\n",
    "}"
   ]
  },
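  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "For reference, the function above applies the following operation\n",
    "to every element of its input:\n",
    "\n",
    "$$\\mathrm{ReLU}(x) = \\max(x, 0).$$\n",
    "\n",
    "As a small worked example, the input $(-2, 0, 3)$ is mapped to $(0, 0, 3)$:\n",
    "negative entries are set to zero and nonnegative entries pass through unchanged."
   ]
  },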
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## The Model\n",
    "\n",
    "Because we are disregarding spatial structure, \n",
    "we `reshape` each 2D image into \n",
    "a flat vector of length `numInputs`.\n",
    "Finally, we implement our model \n",
    "with just a few lines of code."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "attributes": {
     "classes": [],
     "id": "",
     "n": "5"
    }
   },
   "outputs": [],
   "source": [
    "public NDArray net(NDArray X) {\n",
    "    X = X.reshape(new Shape(-1, numInputs));\n",
    "    NDArray H = relu(X.dot(W1).add(b1));\n",
    "    return H.dot(W2).add(b2);\n",
    "}"
   ]
  },
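  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "In matrix form, the `net` function above can be read as follows.\n",
    "Here $n$ denotes the number of examples in a minibatch\n",
    "(notation introduced only for this illustration),\n",
    "so that the reshaped input is $\\mathbf{X} \\in \\mathbb{R}^{n \\times 784}$:\n",
    "\n",
    "$$\\begin{aligned} \\mathbf{H} &= \\mathrm{ReLU}(\\mathbf{X} \\mathbf{W}_1 + \\mathbf{b}_1), & \\mathbf{H} &\\in \\mathbb{R}^{n \\times 256}, \\\\ \\mathbf{O} &= \\mathbf{H} \\mathbf{W}_2 + \\mathbf{b}_2, & \\mathbf{O} &\\in \\mathbb{R}^{n \\times 10}. \\end{aligned}$$\n",
    "\n",
    "The bias vectors are added to every row via broadcasting.\n",
    "Note that the output $\\mathbf{O}$ consists of unnormalized scores (logits):\n",
    "the softmax is not applied here but inside the loss function."
   ]
  },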
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## The Loss Function\n",
    "\n",
    "To ensure numerical stability,\n",
    "and because we already implemented\n",
    "the softmax function from scratch\n",
    "(:numref:`sec_softmax_scratch`),\n",
    "we leverage DJL's built-in function\n",
    "for calculating the softmax and cross-entropy loss.\n",
    "Recall our earlier discussion of these intricacies \n",
    "(:numref:`sec_mlp`).\n",
    "We encourage the interested reader \n",
    "to examine the source code for `SoftmaxCrossEntropyLoss`\n",
    "to deepen their knowledge of implementation details."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "attributes": {
     "classes": [],
     "id": "",
     "n": "6"
    }
   },
   "outputs": [],
   "source": [
    "Loss loss = Loss.softmaxCrossEntropyLoss();"
   ]
  },
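  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To recall why combining the two steps helps,\n",
    "here is a sketch of the general technique\n",
    "(we have not checked that DJL uses exactly this form internally).\n",
    "For a single example with logits $\\mathbf{o}$ and true class $y$,\n",
    "the softmax cross-entropy loss can be written without ever forming the probabilities:\n",
    "\n",
    "$$l(\\mathbf{o}, y) = -\\log \\frac{\\exp(o_y)}{\\sum_{k} \\exp(o_k)} = \\log \\sum_{k} \\exp(o_k) - o_y.$$\n",
    "\n",
    "The remaining log-sum-exp term is typically evaluated after subtracting the largest logit,\n",
    "which leaves its value unchanged but keeps every exponent at or below zero:\n",
    "\n",
    "$$\\log \\sum_{k} \\exp(o_k) = \\bar{o} + \\log \\sum_{k} \\exp(o_k - \\bar{o}), \\quad \\bar{o} = \\max_k o_k.$$\n",
    "\n",
    "This avoids the overflow that a naive $\\exp(o_k)$ would cause for large logits."
   ]
  },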
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Training\n",
    "\n",
    "Fortunately, the training loop for MLPs\n",
    "is exactly the same as for softmax regression.\n",
    "\n",
    "We train the model just as we did in Chapter 3\n",
    "(see :numref:`sec_softmax_scratch`),\n",
    "setting the number of epochs to $10$ \n",
    "and the learning rate to $0.5$."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "int numEpochs = Integer.getInteger(\"MAX_EPOCH\", 10);\n",
    "float lr = 0.5f;\n",
    "\n",
    "double[] trainLoss;\n",
    "double[] testAccuracy;\n",
    "double[] epochCount;\n",
    "double[] trainAccuracy;\n",
    "\n",
    "trainLoss = new double[numEpochs];\n",
    "trainAccuracy = new double[numEpochs];\n",
    "testAccuracy = new double[numEpochs];\n",
    "epochCount = new double[numEpochs];"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "attributes": {
     "classes": [],
     "id": "",
     "n": "7"
    }
   },
   "outputs": [],
   "source": [
    "float epochLoss = 0f;\n",
    "float accuracyVal = 0f;\n",
    "\n",
    "for (int epoch = 1; epoch <= numEpochs; epoch++) {\n",
    "    System.out.print(\"Running epoch \" + epoch + \"...... \");\n",
    "    // Iterate over dataset\n",
    "    for (Batch batch : trainIter.getData(manager)) {\n",
    "\n",
    "        NDArray X = batch.getData().head();\n",
    "        NDArray y = batch.getLabels().head();\n",
    "\n",
    "        try(GradientCollector gc = Engine.getInstance().newGradientCollector()) {\n",
    "            NDArray yHat = net(X); // net function call\n",
    "\n",
    "            NDArray lossValue = loss.evaluate(new NDList(y), new NDList(yHat));\n",
    "            NDArray l = lossValue.mul(batchSize);\n",
    "\n",
    "            accuracyVal += Training.accuracy(yHat, y);\n",
    "            epochLoss += l.sum().getFloat();\n",
    "\n",
    "            gc.backward(l); // gradient calculation\n",
    "        }\n",
    "\n",
    "        batch.close();\n",
    "        Training.sgd(params, lr, batchSize); // updater\n",
    "    }\n",
    "\n",
    "    trainLoss[epoch-1] = epochLoss/trainIter.size();\n",
    "    trainAccuracy[epoch-1] = accuracyVal/trainIter.size();\n",
    "\n",
    "    epochLoss = 0f;\n",
    "    accuracyVal = 0f;    \n",
    "    // testing now\n",
    "    for (Batch batch : testIter.getData(manager)) {\n",
    "\n",
    "        NDArray X = batch.getData().head();\n",
    "        NDArray y = batch.getLabels().head();\n",
    "\n",
    "        NDArray yHat = net(X); // net function call\n",
    "        accuracyVal += Training.accuracy(yHat, y);\n",
    "        batch.close(); // release the batch's native memory, as in the training loop\n",
    "    }\n",
    "\n",
    "    testAccuracy[epoch-1] = accuracyVal/testIter.size();\n",
    "    epochCount[epoch-1] = epoch;\n",
    "    accuracyVal = 0f;\n",
    "    System.out.println(\"Finished epoch \" + epoch);\n",
    "}\n",
    "\n",
    "System.out.println(\"Finished training!\");"
   ]
  },
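  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "It may help to spell out the update that each iteration of the loop above is meant to perform.\n",
    "We assume here that `Training.sgd` implements the minibatch stochastic gradient descent\n",
    "used in the earlier chapters, i.e., that it divides each gradient by `batchSize`;\n",
    "its source lives in the loaded utility file and is not shown in this section.\n",
    "Writing $\\theta$ for any of the parameters, $\\eta$ for the learning rate `lr`,\n",
    "and $\\mathcal{B}$ for the current minibatch, the update is\n",
    "\n",
    "$$\\theta \\leftarrow \\theta - \\frac{\\eta}{|\\mathcal{B}|} \\sum_{i \\in \\mathcal{B}} \\nabla_{\\theta}\\, l^{(i)},$$\n",
    "\n",
    "with $\\eta = 0.5$ and $|\\mathcal{B}| = 256$ in our setting.\n",
    "This also explains why the loop multiplies the value returned by `loss.evaluate` by `batchSize`:\n",
    "the loss object reports the mean over the minibatch,\n",
    "and rescaling it gives the sum that the update above expects.\n",
    "For a final minibatch that contains fewer than `batchSize` examples,\n",
    "this rescaling is only approximate."
   ]
  },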
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "String[] lossLabel = new String[trainLoss.length + testAccuracy.length + trainAccuracy.length];\n",
    "\n",
    "Arrays.fill(lossLabel, 0, trainLoss.length, \"train loss\");\n",
    "Arrays.fill(lossLabel, trainLoss.length, trainLoss.length + trainAccuracy.length, \"train acc\");\n",
    "Arrays.fill(lossLabel, trainLoss.length + trainAccuracy.length,\n",
    "                trainLoss.length + testAccuracy.length + trainAccuracy.length, \"test acc\");\n",
    "\n",
    "Table data = Table.create(\"Data\").addColumns(\n",
    "    DoubleColumn.create(\"epochCount\", ArrayUtils.addAll(epochCount, ArrayUtils.addAll(epochCount, epochCount))),\n",
    "    DoubleColumn.create(\"loss\", ArrayUtils.addAll(trainLoss, ArrayUtils.addAll(trainAccuracy, testAccuracy))),\n",
    "    StringColumn.create(\"lossLabel\", lossLabel)\n",
    ");\n",
    "\n",
    "render(LinePlot.create(\"\", data, \"epochCount\", \"loss\", \"lossLabel\"), \"text/html\");"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Summary\n",
    "\n",
    "We saw that implementing a simple MLP is easy, \n",
    "even when done manually.\n",
    "That said, with a large number of layers, \n",
    "this can still get messy \n",
    "(e.g., naming and keeping track of our model's parameters, etc.).\n",
    "\n",
    "## Exercises\n",
    "\n",
    "1. Change the value of the hyperparameter `numHiddens` and see how this hyperparameter influences your results. Determine the best value of this hyperparameter, keeping all others constant.\n",
    "1. Try adding an additional hidden layer to see how it affects the results.\n",
    "1. How does changing the learning rate alter your results? Fixing the model architecture and other hyperparameters (including number of epochs), what learning rate gives you the best results? \n",
    "1. What is the best result you can get by optimizing over all the hyperparameters (learning rate, number of epochs, number of hidden layers, number of hidden units per layer) jointly?\n",
    "1. Describe why it is much more challenging to deal with multiple hyperparameters. \n",
    "1. What is the smartest strategy you can think of for structuring a search over multiple hyperparameters?\n",
    "\n"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Java",
   "language": "java",
   "name": "java"
  },
  "language_info": {
   "codemirror_mode": "java",
   "file_extension": ".jshell",
   "mimetype": "text/x-java-source",
   "name": "Java",
   "pygments_lexer": "java",
   "version": "14.0.2+12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
